{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "78c4e3a4",
   "metadata": {},
   "source": [
    "## DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea551d48",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:04:56.287186Z",
     "start_time": "2024-05-11T02:04:51.246562Z"
    }
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "from IPython import display\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3408c33e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:04:58.165620Z",
     "start_time": "2024-05-11T02:04:58.157099Z"
    }
   },
   "outputs": [],
   "source": [
    "class GymHelper:\n",
    "    def __init__(self,env,figsize=(3,3)):\n",
    "        self.env=env\n",
    "        self.figsize=figsize\n",
    "        plt.figure(figsize=figsize)\n",
    "        self.img=plt.imshow(env.render())\n",
    "    def render(self,title=None):\n",
    "        img_data=self.env.render()\n",
    "        self.img.set_data(img_data)\n",
    "        display.display(plt.gcf())\n",
    "        display.clear_output(wait=True)\n",
    "        if title:\n",
    "            plt.title(title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dd08a077",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:05:02.879560Z",
     "start_time": "2024-05-11T02:05:00.054555Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADbCAYAAADNoUzuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZx0lEQVR4nO3dfVBU590+8GsXdpe33Q0vcY8rYNBArEFtxWhlrGBEjAmxTmeqralJZpyMNkLkp05Gm7ZiWoGaRtOOVaet1c70scSM0tjWUklFGodkohgi6lPTF+RNNojC7oLLLuzevz/yeJoVUBdhb5DrM3P+OPe59+z3HNhrz97nnF2NEEKAiEgCrewCiGjsYgARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAkZ+3334bjz/+OMLDw6HRaLBs2TJoNJpBrevUqVPQaDQ4depUQI975JFHkJOTM6jnDMSVK1eg0Whw8ODBYX8u6l+o7AJo5Lh27RpWrVqFp556Cnv27IHBYIDVasXmzZsHtb6ZM2figw8+wNSpU4e4UnpQMIBI9emnn6Knpwff+c53kJGRobYnJiYOan0mkwlf/epXh6o8egDxIxgBAF588UXMmzcPALBixQpoNBpkZmaioKCgz0ewWx+RysrKMHPmTISHh2PKlCn4zW9+49evv49g//nPf/Ctb30LVqsVBoMBFosFCxcuRE1NTZ+a7rZ+ALDZbFizZg3i4+Oh1+uRlJSEbdu2obe316/f1atXsXz5chiNRpjNZqxYsQI2m22Qe4uGCo+ACADwgx/8ALNnz8a6detQWFiIBQsWwGQy4fDhw/32/+STT7Bx40Zs3rwZFosFv/71r7F69Wo8+uijmD9//oDP8/TTT8Pr9WLHjh1ITExEW1sbqqqq0NHREfD6bTYbZs+eDa1Wix/+8IeYPHkyPvjgA/z4xz/GlStXcODAAQCAy+VCVlYWrl69iqKiIqSkpODPf/4zVqxYMTQ7jwZPEP2fiooKAUC88847atvWrVvF7f8mEydOFGFhYaK+vl5tc7lcIiYmRqxZs6bP+ioqKoQQQrS1tQkA4q233rpjHfe6/jVr1oioqCi/fkII8dOf/lQAEBcvXhRCCLF3714BQLz77rt+/V566SUBQBw4cOCO9dDw4UcwGpQvf/nLfmNDYWFhSElJQX19/YCPiYmJweTJk/HGG29g586d+Pjjj+Hz+Qa9/j/96U9YsGABrFYrent71WnJkiUAgMrKSgBARUUFjEYjli5d6vccK1euDHzDaUgxgGhQYmNj+7QZDAa4XK4BH6PRaPC3v/0Nixcvxo4dOzBz5kw8/PDDeOWVV+B0OgNe/2effYY//vGP0Ol0ftPjjz8OAGhrawMAXL9+HRaLpc/6FEW5t42lYcMxIAqqiRMnYv/+/QA+P+t2+PBhFBQUwOPxYN++fQGtKy4uDtOnT8f27dv7XW61WgF8HmYfffRRn+UchJaPAUTSpKSk4Pvf/z6OHDmCc+fOBfz4nJwcHD9+HJMnT0Z0dPSA/RYsWIDDhw/j2LFjfh/DDh06NKi6aegwgChozp8/j9zcXHzzm99EcnIy9Ho9Tp48ifPnzw/qYsfXX38d5eXlSE9PxyuvvILHHnsM3d3duHLlCo4fP459+/YhPj4ezz//PHbt2oXnn38e27dvR3JyMo4fP46//vWvw7CVFAgGEAWNoiiYPHky9uzZg8bGRmg0GkyaNAlvvvkm8vLyAl7f+PHjcfbsWfzoRz/CG2+8gaamJhiNRiQlJeGpp55Sj4oiIiJw8uRJrF+/Hps3b4ZGo0F2djZKSkqQnp4+1JtJAdAIwV/FICI5eBaMiKRhABGRNAwgIpJGagDt2bMHSUlJCAsLQ1paGt5//32Z5RBRkEkLoLfffhv5+fl47bXX8PHHH+NrX/salixZgoaGBlklEVGQSTsLNmfOHMycORN79+5V2770pS9h2bJlKCoqklESEQWZlOuAPB4Pqqur+1x8lp2djaqqqj793W433G63Ou/z+XDjxg3ExsYO+utCiWj4CCHgdDphtVqh1Q78QUtKALW1tcHr9fa5QdBisfR7f05RURG2bdsWrPKIaIg0NjYiPj5+wOVSr4S+/ehFCNHvEc2WLVuwYcMGdd5utyMxMRGNjY0wmUzDXicRBcbhcCAhIQFGo/GO/aQEUFxcHEJCQvoc7bS2tvb7tQkGgwEGg6FPu8lkYgARjWB3GyKRchZMr9cjLS0N5eXlfu23biwkorFB2kewDRs2YNWqVZg1axbmzp2LX/7yl2hoaMDatWtllUREQSYtgFasWIHr16/j9ddfR0tLC1JTU3H8+HFMnDhRVklEFGSj8m54h8MBs9kMu93OMSCiEeheX6O8F4yIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAEZE0DCAikoYBRETSBBxAf//73/Hss8/CarVCo9HgD3/4g99yIQQKCgpgtVoRHh6OzMxMXLx40a+P2+1GXl4e4uLiEBkZiaVLl6Kpqem+NoSIRp+AA6irqwszZszA7t27+12+Y8cO7Ny5E7t378aZM2egKAoWLVoEp9Op9snPz0dpaSlKSkpw+vRpdHZ2IicnB16vd/BbQkSjj7gPAERpaak67/P5hKIoori4WG3r7u4WZrNZ7Nu3TwghREdHh9DpdKKkpETt09zcLLRarSgrK7un57Xb7QKAsNvt91M+EQ2Te32NDukYUF1dHWw2G7Kzs9U2g8GAjIwMVFVVAQCqq6vR09Pj18dqtSI1NVXtczu32w2Hw+E3EdHoN6QBZLPZAAAWi8Wv3WKxqMtsNhv0ej2io6MH7HO7oqIimM1mdUpISBjKsolIkmE5C6bRaPzmhRB92m53pz5btmyB3W5Xp8bGxiGrlYjkGdIAUhQFAPocybS2tqpHRYqiwOPxoL29fcA+tzMYDDCZTH4TEY1+QxpASUlJUBQF5eXlapvH40FlZSXS09MBAGlpadDpdH59WlpacOHCBbUPEY0NoYE+oLOzE//617/U+bq6OtTU1CAmJgaJiYnIz89HYWEhkpOTkZycjMLCQkRERGDlypUAALPZjNWrV2Pjxo2IjY1FTEwMNm3ahGnTpiErK2votoyIRr5AT69VVFQIAH2mF154QQjx+an4rVu3CkVRhMFgEPPnzxe1tbV+63C5XCI3N1fExMSI8PBwkZOTIxoaGu65Bp6GJxrZ7vU1qhFCCIn5NygOhwNmsxl2u53jQUQj0L2+RnkvGBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAEZE0Af8uGNFw8Xq60XrpFAzGOIQ9pEAf+RBCDJF3/VlvGr0YQDRi3LzeiKvVf4KvtwfaUD10kWY88rXvwDRhiuzSaJgwgGhEEELA1X4Vvl4PAMDX60avy4nQsEjJldFw4hgQjRj2hgt+86HhRhiMD0uqhoKBAUQjgtd9E92OVr824/gUaEP1kiqiYGAA0YjQc9MOt6PtCy0aRMQmQKPlv+iDjH9dGhG62uohfF51XqPVIsoySWJFFAwBBVBRURGeeOIJGI1GjBs3DsuWLcPly5f9+gghUFBQAKvVivDwcGRmZuLixYt+fdxuN/Ly8hAXF4fIyEgsXboUTU1N9781NCoJIXCzrQEQPrVNF/EQdJEPySuKgiKgAKqsrMS6devw4Ycfory8HL29vcjOzkZXV5faZ8eOHdi5cyd2796NM2fOQFEULFq0CE6nU+2Tn5+P0tJSlJSU4PTp0+js7EROTg68Xm9/T0sPOOHtgaPZ/43MYH4YunCjpIooaMR9aG1tFQBEZWWlEEIIn88nFEURxcXFap/u7m5hNpvFvn37hBBCdHR0CJ1OJ0pKStQ+zc3NQqvVirKysnt6XrvdLgAIu91+P+XTCOHq+EycO/j/xEf7XlKnq+eOC5/PJ7s0GqR7fY3e1xiQ3W4HAMTExAAA6urqYLPZkJ2drfYxGAzIyMhAVVUVAKC6uho9PT1+faxWK1JTU9U+t3O73XA4HH4TPTi6O1rQ292pzmu0oQiPjecV0GPAoANICIENGzZg3rx5SE1NBQDYbDYAgMVi8etrsVjUZTabDXq9HtHR0QP2uV1RURHMZrM6JSQkDLZsGoG6rtX7zYfowxARGy+pGgqmQQdQbm4uzp8/j9///vd9lt3+ziWEuOu72Z36bNmyBXa7XZ0aGxsHWzaNMMLn7RNABtPDCNFHSKqIgmlQAZSXl4djx46hoqIC8fH/fadSFAUA+hzJtLa2qkdFiqLA4/Ggvb19wD63MxgMMJlMfhM9GHpczj4BFBGXyAsQx4iAAkgIgdzcXBw9ehQnT55EUlKS3/KkpCQoioLy8nK1zePxoLKyEunp6QCAtLQ06HQ6vz4tLS24cOGC2ofGjp6bdvh6ur/QooFpwhSO/4wRAd2Mum7dOhw6dAjvvvsujEajeqRjNpsRHh4OjUaD/Px8FBYWIjk5GcnJySgsLERERARWrlyp9l29ejU2btyI2NhYxMTEYNOmTZg2bRqysrKGfgtpxBJCoNP2L/UGVADQaENgMMZKrIqCKaAA2rt3LwAgMzPTr/3AgQN48cUXAQCvvvoqXC4XXn75ZbS3t2POnDk4ceIEjMb/XtOxa9cuhIaGYvny5XC5XFi4cCEOHjyIkJCQ+9saGl2EQNe1K35NBvM4hD00Xk49FHQaIYSQXUSgHA4HzGYz7HY7x4NGMW+PG5dKC9Hd3qK2RU9Kw+SFL/EesFHuXl+j/CuTNN0dtttuQAWMyqMAx3/GDAYQSePpvAHh7VHntaF6RFomcwB6DGEAkRRCCDia/9evTRuqh543oI4pDCCSwtfrwc0bzX5tkeOSEMobUMcUBhBJ4XXfRPeNq35tETEToNHwX3Is4V+bpOhs/Q96PS6/tqjxKRz/GWMYQBR0QojPz375fQGZGeHRvP5nrGEAkQR9B6BDDZH8ArIxiAFEQdfb3QW345pfm2nCl6AJ4c/UjTUMIAo6t/M63M7r/23QaBAeOwEAx3/GGgYQBZ2z+X/9xn+0ITpE8QLEMYkBREElhA/dt338CjNbeAHiGMUAoqDy9XjQafu3X5s+KhpaXZikikgmBhAFlafrBnpcdr82U/zjkqoh2RhAFDRCCNy83gSv+6bapgkJ5fU/YxgDiILK0XTb9T9hRv4EzxjGAKKgEcLX5/qfyLhEhHD8Z8xiAFHQ9HR1wNXufwNq2EMWaLT8Kt6xigFEQdNtb4X3izegarQwWvkLGGMZA4iC4vNfwPgnhM+rtoXowxFmHiexKpKNAUTBIXxw3fb9P2Hmh6GPipFUEI0EDCAKCm+PGzfbGvzawswKx3/GOAYQBYWr/So8Nzv82ozjU+QUQyMGA4iCoru9BcLbq86H6MMROe4RDkCPcfwCFrpvvb296OzsvGOf61dq/eZFaBhcvVp4Ojr69NXr9YiIiBjKEmmEYgDRfTt//jyWLVsGn8/X7/IIgw6vP/9VTLJ8/o2HGg3wXtUn2L1lNvr7Wd7nnnsOP/nJT4axYhopGEB03zweD5qbmwcMIGucCd6wx/GRYxqMITcwOeIc/tnQiqbm5n77d/RzVEQPpoDGgPbu3Yvp06fDZDLBZDJh7ty5+Mtf/qIuF0KgoKAAVqsV4eHhyMzMxMWLF/3W4Xa7kZeXh7i4OERGRmLp0qVoamoamq2hEUmJfwL/9jyJGz0TUN+diovOdJz91Ca7LBoBAgqg+Ph4FBcX4+zZszh79iyefPJJfP3rX1dDZseOHdi5cyd2796NM2fOQFEULFq0CE6nU11Hfn4+SktLUVJSgtOnT6OzsxM5OTnwer0DPS2NcuOViYBW/39zGvz7mg4Nnzmk1kQjQ0AB9Oyzz+Lpp59GSkoKUlJSsH37dkRFReHDDz+EEAJvvfUWXnvtNXzjG99Aamoqfvvb3+LmzZs4dOgQAMBut2P//v148803kZWVha985Sv43e9+h9raWrz33nvDsoEkV4hWgyeneGHADQACWvRA03kO9i7XXR9LD75BjwF5vV6888476Orqwty5c1FXVwebzYbs7Gy1j8FgQEZGBqqqqrBmzRpUV1ejp6fHr4/VakVqaiqqqqqwePHigGr4xz/+gaioqMFuAg2Rurq6AZd5fQK7/qcUCRM+genhGZiRGIJ/fvrhgONFANDe3o5Lly4NR6kUJHc7K3pLwAFUW1uLuXPnoru7G1FRUSgtLcXUqVNRVVUFALBYLH79LRYL6uvrAQA2mw16vR7R0dF9+thsA48JuN1uuN1udd7h+Pzw3W63o7e3d6CHUZDc7Z+t6ZoDTddqANTgPa0Gor9TX1/g8Xg4ED3KdXV13VO/gAPoscceQ01NDTo6OnDkyBG88MILqKysVJfffmGZEOKuF5vdrU9RURG2bdvWp33OnDkwmUwBbgENNa323j/Je313SR98/oaUnp5+PyWRZLcOEu4m4Cuh9Xo9Hn30UcyaNQtFRUWYMWMGfvazn0FRFADocyTT2tqqHhUpigKPx4P29vYB+/Rny5YtsNvt6tTY2Bho2UQ0At33rRhCCLjdbiQlJUFRFJSXl6vLPB4PKisr1XeztLQ06HQ6vz4tLS24cOHCHd/xDAaDeur/1kREo19AH8G+973vYcmSJUhISIDT6URJSQlOnTqFsrIyaDQa5Ofno7CwEMnJyUhOTkZhYSEiIiKwcuVKAIDZbMbq1auxceNGxMbGIiYmBps2bcK0adOQlZU1LBtIRCNXQAH02WefYdWqVWhpaYHZbMb06dNRVlaGRYsWAQBeffVVuFwuvPzyy2hvb8ecOXNw4sQJGI1GdR27du1CaGgoli9fDpfLhYULF+LgwYMICeHXMoxWISEhMJlMdzyzFYiwMH5H9FihEeJu5yRGHofDAbPZDLvdzo9jI4Db7UZra+uQrS8yMhIxMfyistHsXl+jvBeM7pvBYEBCQoLsMmgU4vcBEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNAwgIpKGAURE0jCAiEgaBhARScMAIiJpGEBEJA0DiIikYQARkTQMICKShgFERNIwgIhIGgYQEUnDACIiaRhARCQNA4iIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImlCZRcwGEIIAIDD4ZBcCRH159Zr89ZrdSCjMoCcTicAICEhQXIlRHQnTqcTZrN5wOUacbeIGoF8Ph8uX76MqVOnorGxESaTSXZJo4LD4UBCQgL3WYC43wInhIDT6YTVaoVWO/BIz6g8AtJqtZgwYQIAwGQy8Z8iQNxng8P9Fpg7HfncwkFoIpKGAURE0ozaADIYDNi6dSsMBoPsUkYN7rPB4X4bPqNyEJqIHgyj9giIiEY/BhARScMAIiJpGEBEJM2oDKA9e/YgKSkJYWFhSEtLw/vvvy+7JGmKiorwxBNPwGg0Yty4cVi2bBkuX77s10cIgYKCAlitVoSHhyMzMxMXL1706+N2u5GXl4e4uDhERkZi6dKlaGpqCuamSFNUVASNRoP8/Hy1jfssSMQoU1JSInQ6nfjVr34lLl26JNavXy8iIyNFfX297NKkWLx4sThw4IC4cOGCqKmpEc8884xITEwUnZ2dap/i4mJhNBrFkSNHRG1trVixYoUYP368cDgcap+1a9eKCRMmiPLycnHu3DmxYMECMWPGDNHb2ytjs4Lmo48+Eo888oiYPn26WL9+vdrOfRYcoy6AZs+eLdauXevXNmXKFLF582ZJFY0sra2tAoCorKwUQgjh8/mEoiiiuLhY7dPd3S3MZrPYt2+fEEKIjo4OodPpRElJidqnublZaLVaUVZWFtwNCCKn0ymSk5NFeXm5yMjIUAOI+yx4RtVHMI/Hg+rqamRnZ/u1Z2dno6qqSlJVI4vdbgcAxMTEAADq6upgs9n89pnBYEBGRoa6z6qrq9HT0+PXx2q1IjU19YHer+vWrcMzzzyDrKwsv3bus+AZVTejtrW1wev1wmKx+LVbLBbYbDZJVY0cQghs2LAB8+bNQ2pqKgCo+6W/fVZfX6/20ev1iI6O7tPnQd2vJSUlOHfuHM6cOdNnGfdZ8IyqALpFo9H4zQsh+rSNRbm5uTh//jxOnz7dZ9lg9tmDul8bGxuxfv16nDhxAmFhYQP24z4bfqPqI1hcXBxCQkL6vMO0trb2ebcaa/Ly8nDs2DFUVFQgPj5ebVcUBQDuuM8URYHH40F7e/uAfR4k1dXVaG1tRVpaGkJDQxEaGorKykr8/Oc/R2hoqLrN3GfDb1QFkF6vR1paGsrLy/3ay8vLkZ6eLqkquYQQyM3NxdGjR3Hy5EkkJSX5LU9KSoKiKH77zOPxoLKyUt1naWlp0Ol0fn1aWlpw4cKFB3K/Lly4ELW1taipqVGnWbNm4bnnnkNNTQ0mTZrEfRYsEgfAB+XWafj9+/eLS5cuifz8fBEZGSmuXLkiuzQpvvvd7wqz2SxOnTolWlpa1OnmzZtqn+LiYmE2m8XRo0dFbW2t+Pa3v93vKeX4+Hjx3nvviXPnzoknn3xyTJ1S/uJZMCG4z4Jl1AWQEEL84he/EBMnThR6vV7MnDlTPeU8FgHodzpw4IDax+fzia1btwpFUYTBYBDz588XtbW1futxuVwiNzdXxMTEiPDwcJGTkyMaGhqCvDXy3B5A3GfBwa/jICJpRtUYEBE9WBhARCQNA4iIpGEAEZE0DCAikoYBRETSMICISBoGEBFJwwAiImkYQEQkDQOIiKRhABGRNP8fvHjc6aTtvSYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "env=gym.make(\"CartPole-v1\",render_mode=\"rgb_array\")\n",
    "env.reset()\n",
    "gym_helper=GymHelper(env)\n",
    "i=0\n",
    "while True:\n",
    "    gym_helper.render(title=str(i))\n",
    "    action=env.action_space.sample()\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    done=terminated or truncated\n",
    "    i+=1\n",
    "    if done:\n",
    "        break\n",
    "gym_helper.render(\"finished\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea6bec31",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:05:09.397313Z",
     "start_time": "2024-05-11T02:05:04.890313Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm import *\n",
    "import collections\n",
    "import time\n",
    "import random\n",
    "import sys\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71ff1b51",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:05:11.280283Z",
     "start_time": "2024-05-11T02:05:11.273765Z"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(Net,self).__init__()\n",
    "        self.input_dim=input_dim\n",
    "        self.output_dim=output_dim\n",
    "        self.fc=nn.Sequential(\n",
    "            nn.Linear(self.input_dim,64),\n",
    "            nn.ReLU(),\n",
    "            nn.linear(64,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,self.output_dim)\n",
    "        )\n",
    "    def forward(self,state):\n",
    "        action=self.fc(state)\n",
    "        return action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3df5e90e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-11T02:05:18.710865Z",
     "start_time": "2024-05-11T02:05:18.704334Z"
    }
   },
   "outputs": [],
   "source": [
    "class Replay_buffer:\n",
    "    def __init__(self,max_size):\n",
    "        self.max_size=max_size\n",
    "        self.buffer=collections.deque(maxlen=self.max_size)\n",
    "    def add(self,state,action,reward,next_state,done):\n",
    "        experience=(state,action,reward,next_state,done)\n",
    "        self.buffer.append(experience)\n",
    "    def sample(self,batch_size):\n",
    "        batch=random.sample(self.buffer,batch_size)\n",
    "        state,action,reward,next_state,done=zip(*batch)\n",
    "        return state,action,reward,next_state,done\n",
    "    def __len__(self):\n",
    "        return len(self.buffer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aaaa35d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQN:\n",
    "    def __init__(self,env,lr=0.001,gamma=0.99,buffer_size=10000,T=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
